import torch
import pybuda
{% if test_format %}
import pytest
from pybuda.verify import verify_module, VerifyConfig
{% endif %}{% if randomizer_config.debug_shapes %}
from test.random.rgg import DebugUtils{% endif %}
from pybuda import PyBudaModule, Tensor

{#
  Class section of the generated-model template: renders one PyBudaModule
  subclass named GeneratedTestModel_<test_index>_<random_seed>.

  Template variables read in this section:
    test_index, random_seed - combined into the generated class name and
                              echoed in the informational Python comments.
    test_id                 - embedded in the default module name, in the
                              `testname` attribute and in the `# id:` comment.
    graph_builder_name      - only echoed in the `# graph_builder:` comment.
    graph.nodes             - iterated twice: in __init__ every node whose
                              operator has `is_layer` set gets an attribute
                              `self.<layer_name>`; in forward() every node
                              emits one step (shape comment, `inputs` list,
                              assignment to `<out_value>`).
    graph.constant_nodes    - each one is registered with add_constant and
                              filled via set_constant with torch.randn data
                              whose shape is
                              reduce_microbatch_size(constant_node.input_shape).
    graph.input_nodes       - each one becomes a `pybuda.Tensor` parameter of
                              forward(), named after the node's out_value.
    randomizer_config       - `debug_shapes` adds a print of the formatted
                              inputs per node; `verify_shapes` adds an assert
                              comparing `<out_value>.shape.dims` with
                              reduce_microbatch_size(node.output_shape).

  Helper callables expected in the render context: constructor_kwargs,
  forward_args, forward_kwargs, reduce_microbatch_size.

  How a forward step is rendered:
    - inputs flagged `constant` are read with self.get_constant(...), all
      other inputs are referenced by their out_value name;
    - layer nodes are called with inputs[0] only;
    - non-layer nodes use operator.forward_code() when that attribute is
      truthy, otherwise a call of operator.full_name with the node name,
      forward_args and forward_kwargs.

  NOTE(review): forward() ends with a hard-coded `return v`; presumably the
  graph builder names the final node's out_value `v` - confirm against the
  code that assigns out_value.
  NOTE(review): tags are packed onto shared lines, presumably to control
  blank lines in the rendered file; keep tag placement intact when editing.

  TODO replace empty new lines with spaces to keep formatting in pipeline
#}
class GeneratedTestModel_{{ test_index }}_{{ random_seed }}(PyBudaModule):
    # graph_builder: {{ graph_builder_name }}
    # id: {{ test_id }}
    # params.test_index: {{ test_index }}
    # params.random_seed: {{ random_seed}}

    def __init__(self, module_name: str = "Buda Test GeneratedTestModel_{{ test_id }}"):
        super(GeneratedTestModel_{{ test_index }}_{{ random_seed }}, self).__init__(module_name)
        self.testname = "Operator Test GeneratedTestModel_{{ test_id }}"
{% for node in graph.nodes %}{% if node.operator.is_layer %}        
        self.{{ node.layer_name }} = {{ node.operator.full_name }}({{ constructor_kwargs(node=node) }}){% endif %}{% endfor %}
        {% for constant_node in graph.constant_nodes %}
        self.add_constant("{{ constant_node.out_value }}")
        self.set_constant("{{ constant_node.out_value }}", torch.randn({{ reduce_microbatch_size(constant_node.input_shape) }})){% endfor %}

    def forward(self{% for node in graph.input_nodes %},
            {{ node.out_value }}: pybuda.Tensor{% endfor %}
        ) -> pybuda.Tensor:
        {% for node in graph.nodes %}

        # shapes: {{ node.input_shapes }} -> {{ node.output_shape }}
        inputs = [{% for input_node in node.inputs %}{% if input_node.constant %}self.get_constant("{{ input_node.out_value }}"){% else %}{{ input_node.out_value }}{% endif %}{% if not loop.last %}, {% endif %}{% endfor %}]{% if randomizer_config.debug_shapes %}
        print(f"{{ node.layer_name }} inputs: {DebugUtils.format_tensors(inputs)}"){% endif %}{% if node.operator.is_layer %}
        {{ node.out_value }} = self.{{ node.layer_name }}(inputs[0]){% else %}
        {{ node.out_value }} = {% if node.operator.forward_code %}{{node.operator.forward_code()}}{% else %}{{ node.operator.full_name }}('{{ node.node_name }}', {{ forward_args(node=node) }}, {{ forward_kwargs(node=node) }}){% endif %}{% endif %}{% if randomizer_config.verify_shapes %}
        assert {{ node.out_value }}.shape.dims == {{ reduce_microbatch_size(node.output_shape) }}, f"Unexpected output shape of {{ node.out_value }} { {{ node.out_value }}.shape } <> {{ reduce_microbatch_size(node.output_shape) }}"{% endif %}{% endfor %}

        return v
{% if test_format %}{#
  Pytest section, rendered only when `test_format` is truthy (the same flag
  guards the pytest / verify_module imports at the top of the template).

  The generated function test_gen_model_<test_index>_<random_seed>:
    - takes a `test_device` argument (presumably a pytest fixture supplied
      by the surrounding test suite - confirm) and reads its `devtype` and
      `arch` attributes;
    - builds `input_shapes` with one entry per graph.input_nodes item, using
      each node's `input_shape`;
    - instantiates the generated model under the module name
      pytest_gen_model_<test_id>;
    - passes model, input_shapes and a VerifyConfig to verify_module.

  The xfail marker is emitted as a commented-out line in the rendered file.
#}

# @pytest.mark.xfail(reason="The model triggers a bug.")
def test_gen_model_{{ test_index }}_{{ random_seed }}(test_device):
    
    input_shapes = [
        {% for input_node in graph.input_nodes %}{{ input_node.input_shape }},
        {% endfor %}]
    model = GeneratedTestModel_{{ test_index }}_{{ random_seed }}("pytest_gen_model_{{ test_id }}")

    verify_module(model, input_shapes, VerifyConfig(devtype=test_device.devtype, arch=test_device.arch))

{% endif %}
